{
  "cells": [
    {
      "cell_type": "markdown",
      "id": "cb5a6489",
      "metadata": {
        "id": "cb5a6489"
      },
      "source": [
        "# Document-level Relation Extraction Tutorial\n",
        "\n",
        "> Tutorial author: 黎洲波（zhoubo.li@zju.edu.cn）\n",
        "\n",
        "In this tutorial, we use [DocuNet](http://arxiv.org/abs/2106.03618) to extract relational triples in different sentences. We hope this tutorial can help you understand the process of document-level relation extraction.\n",
        "\n",
        "This tutorial uses `Python3`."
      ]
    },
    {
      "cell_type": "markdown",
      "id": "9e00a908",
      "metadata": {
        "id": "9e00a908"
      },
      "source": [
        "## RE\n",
        "**Relation extraction** (RE), a key task in information extraction, predicts semantic relations between pairs of entities from unstructured\n",
        "texts.\n",
        "\n",
        "## Document-level RE\n",
        "Document-level RE extracts relations from multi-sentence in one document. An example is shown in the following picture, in which named entities are annotated with colors. Different from sentence-level RE, document-level RE can extract both intra-sentence and inter-sentence relational triples.\n",
        "![文档级关系抽取](img/img1.png)"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "09f8d2f0",
      "metadata": {
        "id": "09f8d2f0"
      },
      "source": [
        "## Dataset\n",
        "\n",
        "There are some document-level RE datasets including DocRED, CDR and GDA. The tutorial uses [DocRED](https://github.com/thunlp/DocRED/tree/master/). The structure of the dataset folder `./data/` is as follow:\n",
        "\n",
        "```\n",
        ".\n",
        "├── dev.json                        # Validation Set\n",
        "├── rel_info.json                   # Relation Label\n",
        "├── rel2id.json                     # Relation Label - ID Map\n",
        "├── test.json                       # Test Set\n",
        "└── train_annotated.json            # Training Set\n",
        "```\n",
        "\n",
        "The data formats of DocRED are described as follow:\n",
        "\n",
        "```\n",
        "Data Format:\n",
        "{\n",
        "  'title',\n",
        "  'sents':     [\n",
        "                  [word in sent 0],\n",
        "                  [word in sent 1]\n",
        "               ]\n",
        "  'vertexSet': [\n",
        "                  [\n",
        "                    { 'name': mention_name, \n",
        "                      'sent_id': mention in which sentence, \n",
        "                      'pos': postion of mention in a sentence, \n",
        "                      'type': NER_type}\n",
        "                    {anthor mention}\n",
        "                  ], \n",
        "                  [anthoer entity]\n",
        "                ]\n",
        "  'labels':   [\n",
        "                {\n",
        "                  'h': idx of head entity in vertexSet,\n",
        "                  't': idx of tail entity in vertexSet,\n",
        "                  'r': relation,\n",
        "                  'evidence': evidence sentences' id\n",
        "                }\n",
        "              ]\n",
        "}\n",
        "```"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "51636aba",
      "metadata": {
        "id": "51636aba"
      },
      "source": [
        "## DocuNet\n",
        "- [DocuNet](http://arxiv.org/abs/2106.03618) used in DeepKE is a semantic segmentation method using Document U-shaped Network based on computer vision (CV) and obtains excellent performance on DocRED dataset.\n",
        "- The framework of DocuNet is as follow:\n",
        "\n",
        "![文档级关系抽取架构图](img/img2.png)"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "aee2bb9f",
      "metadata": {
        "id": "aee2bb9f"
      },
      "source": [
        "## Prepare the runtime environment"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "63f859de",
      "metadata": {
        "id": "63f859de"
      },
      "outputs": [],
      "source": [
        "!pip install deepke\n",
        "!wget 120.27.214.45/Data/re/document/data.tar.gz\n",
        "!tar -xzvf data.tar.gz"
      ]
    },
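    {
      "cell_type": "markdown",
      "id": "3fa1b2c0",
      "metadata": {
        "id": "3fa1b2c0"
      },
      "source": [
        "As a small, optional sanity check, the next cell loads one raw DocRED document and prints the `sents` / `vertexSet` / `labels` structure described in the **Dataset** section. It assumes the archive above was unpacked into `./data/`."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "7d2e9a11",
      "metadata": {
        "id": "7d2e9a11"
      },
      "outputs": [],
      "source": [
        "# Optional sanity check: inspect the structure of one DocRED document.\n",
        "# Assumes data.tar.gz above was unpacked into ./data/.\n",
        "import json\n",
        "\n",
        "with open('./data/train_annotated.json', 'r') as f:\n",
        "    sample = json.load(f)[0]\n",
        "\n",
        "print('title:', sample['title'])\n",
        "print('# sentences:', len(sample['sents']))\n",
        "print('first mention of first entity:', sample['vertexSet'][0][0])\n",
        "print('first label:', sample['labels'][0])"
      ]
    },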
    {
      "cell_type": "markdown",
      "id": "fb01b613",
      "metadata": {
        "id": "fb01b613"
      },
      "source": [
        "## Import modules"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "6dd0e029",
      "metadata": {
        "id": "6dd0e029"
      },
      "outputs": [],
      "source": [
        "import os\n",
        "import time\n",
        "import numpy as np\n",
        "import torch\n",
        "import random\n",
        "import pickle\n",
        "from tqdm import tqdm\n",
        "import ujson as json\n",
        "from opt_einsum import contract\n",
        "\n",
        "import torch\n",
        "import torch.nn as nn\n",
        "import torch.nn.functional as F\n",
        "from torch.utils.data import DataLoader\n",
        "from transformers import AutoConfig, AutoModel, AutoTokenizer\n",
        "from transformers.optimization import AdamW, get_linear_schedule_with_warmup"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "2a1b4764",
      "metadata": {
        "id": "2a1b4764"
      },
      "source": [
        "## Preprocess the dataset"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "bf72cdc3",
      "metadata": {
        "id": "bf72cdc3"
      },
      "outputs": [],
      "source": [
        "rel2id = json.load(open('./data/rel2id.json', 'r'))\n",
        "id2rel = {value: key for key, value in rel2id.items()}\n",
        "\n",
        "\n",
        "def chunks(l, n):\n",
        "    res = []\n",
        "    for i in range(0, len(l), n):\n",
        "        assert len(l[i:i + n]) == n\n",
        "        res += [l[i:i + n]]\n",
        "    return res\n",
        "\n",
        "class ReadDataset:\n",
        "    def __init__(self, dataset: str, tokenizer, max_seq_Length: int = 1024,\n",
        "             transformers: str = 'bert') -> None:\n",
        "        self.transformers = transformers\n",
        "        self.dataset = dataset\n",
        "        self.tokenizer = tokenizer\n",
        "        self.max_seq_Length = max_seq_Length\n",
        "\n",
        "    def read(self, file_in: str):\n",
        "        save_file = file_in.split('.json')[0] + '_' + self.transformers + '_' \\\n",
        "                        + self.dataset + '.pkl'\n",
        "        if self.dataset == 'docred':\n",
        "            return read_docred(self.transformers, file_in, save_file, self.tokenizer, self.max_seq_Length)\n",
        "        else:\n",
        "            raise RuntimeError(\"No read func for this dataset.\")\n",
        "\n",
        "def read_docred(transfermers, file_in, save_file, tokenizer, max_seq_length=1024):\n",
        "    if os.path.exists(save_file):\n",
        "        with open(file=save_file, mode='rb') as fr:\n",
        "            features = pickle.load(fr)\n",
        "            fr.close()\n",
        "        print('load preprocessed data from {}.'.format(save_file))\n",
        "        return features\n",
        "    else:\n",
        "        max_len = 0\n",
        "        up512_num = 0\n",
        "        i_line = 0\n",
        "        pos_samples = 0\n",
        "        neg_samples = 0\n",
        "        features = []\n",
        "        if file_in == \"\":\n",
        "            return None\n",
        "        with open(file_in, \"r\") as fh:\n",
        "            data = json.load(fh)\n",
        "        if transfermers == 'bert':\n",
        "            # entity_type = [\"ORG\", \"-\",  \"LOC\", \"-\",  \"TIME\", \"-\",  \"PER\", \"-\", \"MISC\", \"-\", \"NUM\"]\n",
        "            entity_type = [\"-\", \"ORG\", \"-\",  \"LOC\", \"-\",  \"TIME\", \"-\",  \"PER\", \"-\", \"MISC\", \"-\", \"NUM\"]\n",
        "\n",
        "\n",
        "        for sample in tqdm(data, desc=\"Example\"):\n",
        "            sents = []\n",
        "            sent_map = []\n",
        "\n",
        "            entities = sample['vertexSet']\n",
        "            entity_start, entity_end = [], []\n",
        "            mention_types = []\n",
        "            for entity in entities:\n",
        "                for mention in entity:\n",
        "                    sent_id = mention[\"sent_id\"]\n",
        "                    pos = mention[\"pos\"]\n",
        "                    entity_start.append((sent_id, pos[0]))\n",
        "                    entity_end.append((sent_id, pos[1] - 1))\n",
        "                    mention_types.append(mention['type'])\n",
        "\n",
        "            for i_s, sent in enumerate(sample['sents']):\n",
        "                new_map = {}\n",
        "                for i_t, token in enumerate(sent):\n",
        "                    tokens_wordpiece = tokenizer.tokenize(token)\n",
        "                    if (i_s, i_t) in entity_start:\n",
        "                        t = entity_start.index((i_s, i_t))\n",
        "                        if transfermers == 'bert':\n",
        "                            mention_type = mention_types[t]\n",
        "                            special_token_i = entity_type.index(mention_type)\n",
        "                            special_token = ['[unused' + str(special_token_i) + ']']\n",
        "                        else:\n",
        "                            special_token = ['*']\n",
        "                        tokens_wordpiece = special_token + tokens_wordpiece\n",
        "                        # tokens_wordpiece = [\"[unused0]\"]+ tokens_wordpiece\n",
        "\n",
        "                    if (i_s, i_t) in entity_end:\n",
        "                        t = entity_end.index((i_s, i_t))\n",
        "                        if transfermers == 'bert':\n",
        "                            mention_type = mention_types[t]\n",
        "                            special_token_i = entity_type.index(mention_type) + 50\n",
        "                            special_token = ['[unused' + str(special_token_i) + ']']\n",
        "                        else:\n",
        "                            special_token = ['*']\n",
        "                        tokens_wordpiece = tokens_wordpiece + special_token\n",
        "\n",
        "                        # tokens_wordpiece = tokens_wordpiece + [\"[unused1]\"]\n",
        "                        # print(tokens_wordpiece,tokenizer.convert_tokens_to_ids(tokens_wordpiece))\n",
        "\n",
        "                    new_map[i_t] = len(sents)\n",
        "                    sents.extend(tokens_wordpiece)\n",
        "                new_map[i_t + 1] = len(sents)\n",
        "                sent_map.append(new_map)\n",
        "\n",
        "            if len(sents)>max_len:\n",
        "                max_len=len(sents)\n",
        "            if len(sents)>512:\n",
        "                up512_num += 1\n",
        "\n",
        "            train_triple = {}\n",
        "            if \"labels\" in sample:\n",
        "                for label in sample['labels']:\n",
        "                    evidence = label['evidence']\n",
        "                    r = int(rel2id[label['r']])\n",
        "                    if (label['h'], label['t']) not in train_triple:\n",
        "                        train_triple[(label['h'], label['t'])] = [\n",
        "                            {'relation': r, 'evidence': evidence}]\n",
        "                    else:\n",
        "                        train_triple[(label['h'], label['t'])].append(\n",
        "                            {'relation': r, 'evidence': evidence})\n",
        "\n",
        "            entity_pos = []\n",
        "            for e in entities:\n",
        "                entity_pos.append([])\n",
        "                mention_num = len(e)\n",
        "                for m in e:\n",
        "                    start = sent_map[m[\"sent_id\"]][m[\"pos\"][0]]\n",
        "                    end = sent_map[m[\"sent_id\"]][m[\"pos\"][1]]\n",
        "                    entity_pos[-1].append((start, end,))\n",
        "\n",
        "\n",
        "            relations, hts = [], []\n",
        "            # Get positive samples from dataset\n",
        "            for h, t in train_triple.keys():\n",
        "                relation = [0] * len(rel2id)\n",
        "                for mention in train_triple[h, t]:\n",
        "                    relation[mention[\"relation\"]] = 1\n",
        "                    evidence = mention[\"evidence\"]\n",
        "                relations.append(relation)\n",
        "                hts.append([h, t])\n",
        "                pos_samples += 1\n",
        "\n",
        "            # Get negative samples from dataset\n",
        "            for h in range(len(entities)):\n",
        "                for t in range(len(entities)):\n",
        "                    if h != t and [h, t] not in hts:\n",
        "                        relation = [1] + [0] * (len(rel2id) - 1)\n",
        "                        relations.append(relation)\n",
        "                        hts.append([h, t])\n",
        "                        neg_samples += 1\n",
        "\n",
        "            assert len(relations) == len(entities) * (len(entities) - 1)\n",
        "\n",
        "            if len(hts)==0:\n",
        "                print(len(sent))\n",
        "            sents = sents[:max_seq_length - 2]\n",
        "            input_ids = tokenizer.convert_tokens_to_ids(sents)\n",
        "            input_ids = tokenizer.build_inputs_with_special_tokens(input_ids)\n",
        "\n",
        "            i_line += 1\n",
        "            feature = {'input_ids': input_ids,\n",
        "                       'entity_pos': entity_pos,\n",
        "                       'labels': relations,\n",
        "                       'hts': hts,\n",
        "                       'title': sample['title'],\n",
        "                       }\n",
        "            features.append(feature)\n",
        "\n",
        "\n",
        "\n",
        "        print(\"# of documents {}.\".format(i_line))\n",
        "        print(\"# of positive examples {}.\".format(pos_samples))\n",
        "        print(\"# of negative examples {}.\".format(neg_samples))\n",
        "        print(\"# {} examples len>512 and max len is {}.\".format(up512_num, max_len))\n",
        "\n",
        "\n",
        "        with open(file=save_file, mode='wb') as fw:\n",
        "            pickle.dump(features, fw)\n",
        "        print('finish reading {} and save preprocessed data to {}.'.format(file_in, save_file))\n",
        "\n",
        "        return features"
      ]
    },
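    {
      "cell_type": "markdown",
      "id": "8c4d5e21",
      "metadata": {
        "id": "8c4d5e21"
      },
      "source": [
        "A quick, self-contained illustration of the mention-marking scheme used in `read_docred`: with a BERT tokenizer, each mention is wrapped in type-specific `[unusedN]` markers, where the start-marker index is the NER type's position in `entity_type` and the end-marker index is that position plus 50. The snippet below is only a sketch with a made-up mention."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "9b6f7a32",
      "metadata": {
        "id": "9b6f7a32"
      },
      "outputs": [],
      "source": [
        "# Sketch of the mention-marking scheme from read_docred (illustrative only).\n",
        "entity_type = [\"-\", \"ORG\", \"-\", \"LOC\", \"-\", \"TIME\", \"-\", \"PER\", \"-\", \"MISC\", \"-\", \"NUM\"]\n",
        "mention_type = \"PER\"  # hypothetical mention of type PER\n",
        "i = entity_type.index(mention_type)\n",
        "start_marker, end_marker = '[unused' + str(i) + ']', '[unused' + str(i + 50) + ']'\n",
        "print(start_marker, 'John Smith', end_marker)  # [unused7] John Smith [unused57]"
      ]
    },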
    {
      "cell_type": "markdown",
      "id": "4f98868c",
      "metadata": {
        "id": "4f98868c"
      },
      "source": [
        "## Prepare the Model"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "eefee349",
      "metadata": {
        "id": "eefee349"
      },
      "outputs": [],
      "source": [
        "class AttentionUNet(torch.nn.Module):\n",
        "    \"\"\"\n",
        "    UNet, down sampling & up sampling for global reasoning\n",
        "    \"\"\"\n",
        "\n",
        "    def __init__(self, input_channels, class_number, **kwargs):\n",
        "        super(AttentionUNet, self).__init__()\n",
        "\n",
        "        down_channel = kwargs['down_channel'] # default = 256\n",
        "\n",
        "        down_channel_2 = down_channel * 2\n",
        "        up_channel_1 = down_channel_2 * 2\n",
        "        up_channel_2 = down_channel * 2\n",
        "\n",
        "        self.inc = InConv(input_channels, down_channel)\n",
        "        self.down1 = DownLayer(down_channel, down_channel_2)\n",
        "        self.down2 = DownLayer(down_channel_2, down_channel_2)\n",
        "\n",
        "        self.up1 = UpLayer(up_channel_1, up_channel_1 // 4)\n",
        "        self.up2 = UpLayer(up_channel_2, up_channel_2 // 4)\n",
        "        self.outc = OutConv(up_channel_2 // 4, class_number)\n",
        "\n",
        "    def forward(self, attention_channels):\n",
        "        \"\"\"\n",
        "        Given multi-channel attention map, return the logits of every one mapping into 3-class\n",
        "        :param attention_channels:\n",
        "        :return:\n",
        "        \"\"\"\n",
        "        # attention_channels as the shape of: batch_size x channel x width x height\n",
        "        x = attention_channels\n",
        "        x1 = self.inc(x)\n",
        "        x2 = self.down1(x1)\n",
        "        x3 = self.down2(x2)\n",
        "        x = self.up1(x3, x2)\n",
        "        x = self.up2(x, x1)\n",
        "        output = self.outc(x)\n",
        "        # attn_map as the shape of: batch_size x width x height x class\n",
        "        output = output.permute(0, 2, 3, 1).contiguous()\n",
        "        return output\n",
        "\n",
        "\n",
        "class DoubleConv(nn.Module):\n",
        "    \"\"\"(conv => [BN] => ReLU) * 2\"\"\"\n",
        "\n",
        "    def __init__(self, in_ch, out_ch):\n",
        "        super(DoubleConv, self).__init__()\n",
        "        self.double_conv = nn.Sequential(nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),\n",
        "                                         nn.BatchNorm2d(out_ch),\n",
        "                                         nn.ReLU(inplace=True),\n",
        "                                         nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1),\n",
        "                                         nn.BatchNorm2d(out_ch),\n",
        "                                         nn.ReLU(inplace=True))\n",
        "\n",
        "    def forward(self, x):\n",
        "        x = self.double_conv(x)\n",
        "        return x\n",
        "\n",
        "\n",
        "class InConv(nn.Module):\n",
        "\n",
        "    def __init__(self, in_ch, out_ch):\n",
        "        super(InConv, self).__init__()\n",
        "        self.conv = DoubleConv(in_ch, out_ch)\n",
        "\n",
        "    def forward(self, x):\n",
        "        x = self.conv(x)\n",
        "        return x\n",
        "\n",
        "\n",
        "class DownLayer(nn.Module):\n",
        "\n",
        "    def __init__(self, in_ch, out_ch):\n",
        "        super(DownLayer, self).__init__()\n",
        "        self.maxpool_conv = nn.Sequential(\n",
        "            nn.MaxPool2d(kernel_size=2),\n",
        "            DoubleConv(in_ch, out_ch)\n",
        "        )\n",
        "\n",
        "    def forward(self, x):\n",
        "        x = self.maxpool_conv(x)\n",
        "        return x\n",
        "\n",
        "\n",
        "class UpLayer(nn.Module):\n",
        "\n",
        "    def __init__(self, in_ch, out_ch, bilinear=True):\n",
        "        super(UpLayer, self).__init__()\n",
        "        if bilinear:\n",
        "            self.up = nn.Upsample(scale_factor=2, mode='bilinear',\n",
        "                                  align_corners=True)\n",
        "        else:\n",
        "            self.up = nn.ConvTranspose2d(in_ch // 2, in_ch // 2, 2, stride=2)\n",
        "        self.conv = DoubleConv(in_ch, out_ch)\n",
        "\n",
        "    def forward(self, x1, x2):\n",
        "        x1 = self.up(x1)\n",
        "        diffY = x2.size()[2] - x1.size()[2]\n",
        "        diffX = x2.size()[3] - x1.size()[3]\n",
        "        x1 = F.pad(x1, (diffX // 2, diffX - diffX // 2, diffY // 2, diffY -\n",
        "                        diffY // 2))\n",
        "        x = torch.cat([x2, x1], dim=1)\n",
        "        x = self.conv(x)\n",
        "        return x\n",
        "\n",
        "\n",
        "class OutConv(nn.Module):\n",
        "\n",
        "    def __init__(self, in_ch, out_ch):\n",
        "        super(OutConv, self).__init__()\n",
        "        self.conv = nn.Conv2d(in_ch, out_ch, 1)\n",
        "\n",
        "    def forward(self, x):\n",
        "        x = self.conv(x)\n",
        "        return x\n",
        "\n",
        "class DocREModel(nn.Module):\n",
        "    def __init__(self, config, args, model, emb_size=768, block_size=64, num_labels=-1):\n",
        "        super().__init__()\n",
        "        self.config = config\n",
        "        self.bert_model = model\n",
        "        self.hidden_size = config.hidden_size\n",
        "        self.loss_fnt = ATLoss()\n",
        "\n",
        "        self.head_extractor = nn.Linear(1 * config.hidden_size + args.unet_out_dim, emb_size)\n",
        "        self.tail_extractor = nn.Linear(1 * config.hidden_size + args.unet_out_dim, emb_size)\n",
        "        # self.head_extractor = nn.Linear(1 * config.hidden_size , emb_size)\n",
        "        # self.tail_extractor = nn.Linear(1 * config.hidden_size , emb_size)\n",
        "        self.bilinear = nn.Linear(emb_size * block_size, config.num_labels)\n",
        "\n",
        "        self.emb_size = emb_size\n",
        "        self.block_size = block_size\n",
        "        self.num_labels = num_labels\n",
        "\n",
        "        self.bertdrop = nn.Dropout(0.6)\n",
        "        self.unet_in_dim = args.unet_in_dim\n",
        "        self.unet_out_dim = args.unet_in_dim\n",
        "        self.liner = nn.Linear(config.hidden_size, args.unet_in_dim)\n",
        "        self.min_height = args.max_height\n",
        "        self.channel_type = args.channel_type\n",
        "        self.segmentation_net = AttentionUNet(input_channels=args.unet_in_dim,\n",
        "                                              class_number=args.unet_out_dim,\n",
        "                                              down_channel=args.down_dim)\n",
        "\n",
        "\n",
        "    def encode(self, input_ids, attention_mask,entity_pos):\n",
        "        config = self.config\n",
        "        if config.transformer_type == \"albert\":\n",
        "            start_tokens = [config.cls_token_id]\n",
        "            end_tokens = [config.sep_token_id]\n",
        "        elif config.transformer_type == \"roberta\":\n",
        "            start_tokens = [config.cls_token_id]\n",
        "            end_tokens = [config.sep_token_id, config.sep_token_id]\n",
        "        sequence_output, attention = process_long_input(self.bert_model, input_ids, attention_mask, start_tokens, end_tokens)\n",
        "        return sequence_output, attention\n",
        "\n",
        "    def get_hrt(self, sequence_output, attention, entity_pos, hts):\n",
        "        offset = 1 if self.config.transformer_type in [\"albert\", \"roberta\"] else 0\n",
        "        bs, h, _, c = attention.size()\n",
        "        # ne = max([len(x) for x in entity_pos])  # 本次bs中的最大实体数\n",
        "\n",
        "        hss, tss, rss = [], [], []\n",
        "        entity_es = []\n",
        "        entity_as = []\n",
        "        for i in range(len(entity_pos)):\n",
        "            entity_embs, entity_atts = [], []\n",
        "            for entity_num, e in enumerate(entity_pos[i]):\n",
        "                if len(e) > 1:\n",
        "                    e_emb, e_att = [], []\n",
        "                    for start, end in e:\n",
        "                        if start + offset < c:\n",
        "                            # In case the entity mention is truncated due to limited max seq length.\n",
        "                            e_emb.append(sequence_output[i, start + offset])\n",
        "                            e_att.append(attention[i, :, start + offset])\n",
        "                    if len(e_emb) > 0:\n",
        "                        e_emb = torch.logsumexp(torch.stack(e_emb, dim=0), dim=0)\n",
        "                        e_att = torch.stack(e_att, dim=0).mean(0)\n",
        "                    else:\n",
        "                        e_emb = torch.zeros(self.config.hidden_size).to(sequence_output)\n",
        "                        e_att = torch.zeros(h, c).to(attention)\n",
        "                else:\n",
        "                    start, end = e[0]\n",
        "                    if start + offset < c:\n",
        "                        e_emb = sequence_output[i, start + offset]\n",
        "                        e_att = attention[i, :, start + offset]\n",
        "                    else:\n",
        "                        e_emb = torch.zeros(self.config.hidden_size).to(sequence_output)\n",
        "                        e_att = torch.zeros(h, c).to(attention)\n",
        "                entity_embs.append(e_emb)\n",
        "                entity_atts.append(e_att)\n",
        "            for _ in range(self.min_height-entity_num-1):\n",
        "                entity_atts.append(e_att)\n",
        "\n",
        "            entity_embs = torch.stack(entity_embs, dim=0)  # [n_e, d]\n",
        "            entity_atts = torch.stack(entity_atts, dim=0)  # [n_e, h, seq_len]\n",
        "\n",
        "\n",
        "            entity_es.append(entity_embs)\n",
        "            entity_as.append(entity_atts)\n",
        "            ht_i = torch.LongTensor(hts[i]).to(sequence_output.device)\n",
        "            hs = torch.index_select(entity_embs, 0, ht_i[:, 0])\n",
        "            ts = torch.index_select(entity_embs, 0, ht_i[:, 1])\n",
        "\n",
        "            hss.append(hs)\n",
        "            tss.append(ts)\n",
        "        hss = torch.cat(hss, dim=0)\n",
        "        tss = torch.cat(tss, dim=0)\n",
        "        return hss, tss, entity_es, entity_as\n",
        "\n",
        "    def get_mask(self, ents, bs, ne, run_device):\n",
        "        ent_mask = torch.zeros(bs, ne, device=run_device)\n",
        "        rel_mask = torch.zeros(bs, ne, ne, device=run_device)\n",
        "        for _b in range(bs):\n",
        "            ent_mask[_b, :len(ents[_b])] = 1\n",
        "            rel_mask[_b, :len(ents[_b]), :len(ents[_b])] = 1\n",
        "        return ent_mask, rel_mask\n",
        "\n",
        "\n",
        "    def get_ht(self, rel_enco, hts):\n",
        "        htss = []\n",
        "        for i in range(len(hts)):\n",
        "            ht_index = hts[i]\n",
        "            for (h_index, t_index) in ht_index:\n",
        "                htss.append(rel_enco[i,h_index,t_index])\n",
        "        htss = torch.stack(htss,dim=0)\n",
        "        return htss\n",
        "\n",
        "    def get_channel_map(self, sequence_output, entity_as):\n",
        "        # sequence_output = sequence_output.to('cpu')\n",
        "        # attention = attention.to('cpu')\n",
        "        bs,_,d = sequence_output.size()\n",
        "        # ne = max([len(x) for x in entity_as])  # 本次bs中的最大实体数\n",
        "        ne = self.min_height\n",
        "\n",
        "        index_pair = []\n",
        "        for i in range(ne):\n",
        "            tmp = torch.cat((torch.ones((ne, 1), dtype=int) * i, torch.arange(0, ne).unsqueeze(1)), dim=-1)\n",
        "            index_pair.append(tmp)\n",
        "        index_pair = torch.stack(index_pair, dim=0).reshape(-1, 2).to(sequence_output.device)\n",
        "        map_rss = []\n",
        "        for b in range(bs):\n",
        "            entity_atts = entity_as[b]\n",
        "            h_att = torch.index_select(entity_atts, 0, index_pair[:, 0])\n",
        "            t_att = torch.index_select(entity_atts, 0, index_pair[:, 1])\n",
        "            ht_att = (h_att * t_att).mean(1)\n",
        "            ht_att = ht_att / (ht_att.sum(1, keepdim=True) + 1e-5)\n",
        "            rs = contract(\"ld,rl->rd\", sequence_output[b], ht_att)\n",
        "            map_rss.append(rs)\n",
        "        map_rss = torch.cat(map_rss, dim=0).reshape(bs, ne, ne, d)\n",
        "        return map_rss\n",
        "\n",
        "    def forward(self,\n",
        "                input_ids=None,\n",
        "                attention_mask=None,\n",
        "                labels=None,\n",
        "                entity_pos=None,\n",
        "                hts=None,\n",
        "                instance_mask=None,\n",
        "                ):\n",
        "\n",
        "        sequence_output, attention = self.encode(input_ids, attention_mask,entity_pos)\n",
        "\n",
        "        bs, sequen_len, d = sequence_output.shape\n",
        "        run_device = sequence_output.device.index\n",
        "        ne = max([len(x) for x in entity_pos])  # 本次bs中的最大实体数\n",
        "        ent_mask, rel_mask = self.get_mask(entity_pos, bs, ne, run_device)\n",
        "\n",
        "        # get hs, ts and entity_embs >> entity_rs\n",
        "        hs, ts, entity_embs, entity_as = self.get_hrt(sequence_output, attention, entity_pos, hts)\n",
        "\n",
        "\n",
        "        if self.channel_type == 'context-based':\n",
        "            feature_map = self.get_channel_map(sequence_output, entity_as)\n",
        "            ##print('feature_map:', feature_map.shape)\n",
        "            attn_input = self.liner(feature_map).permute(0, 3, 1, 2).contiguous()\n",
        "\n",
        "        else:\n",
        "            raise Exception(\"channel_type must be specify correctly\")\n",
        "\n",
        "\n",
        "        attn_map = self.segmentation_net(attn_input)\n",
        "        h_t = self.get_ht (attn_map, hts)\n",
        "\n",
        "        hs = torch.tanh(self.head_extractor(torch.cat([hs, h_t], dim=1)))\n",
        "        ts = torch.tanh(self.tail_extractor(torch.cat([ts, h_t], dim=1)))\n",
        "\n",
        "\n",
        "        b1 = hs.view(-1, self.emb_size // self.block_size, self.block_size)\n",
        "        b2 = ts.view(-1, self.emb_size // self.block_size, self.block_size)\n",
        "        bl = (b1.unsqueeze(3) * b2.unsqueeze(2)).view(-1, self.emb_size * self.block_size)\n",
        "        logits = self.bilinear(bl)\n",
        "\n",
        "\n",
        "        output = (self.loss_fnt.get_label(logits, num_labels=self.num_labels))\n",
        "        if labels is not None:\n",
        "            labels = [torch.tensor(label) for label in labels]\n",
        "            labels = torch.cat(labels, dim=0).to(logits)\n",
        "            loss = self.loss_fnt(logits.float(), labels.float())\n",
        "            output = (loss.to(sequence_output), output)\n",
        "        return output"
      ]
    },
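    {
      "cell_type": "markdown",
      "id": "1e8a9c43",
      "metadata": {
        "id": "1e8a9c43"
      },
      "source": [
        "As a quick shape check, `AttentionUNet` can be run standalone on a random entity-pair \"image\". The sizes below mirror the `Config` defined later (`unet_in_dim=3`, `unet_out_dim=256`, `down_dim=256`, `max_height=42`); this is purely illustrative."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "2f9b0d54",
      "metadata": {
        "id": "2f9b0d54"
      },
      "outputs": [],
      "source": [
        "# Minimal shape check of the U-shaped segmentation net on random input.\n",
        "# Channel/size values mirror the Config used later; illustrative only.\n",
        "unet = AttentionUNet(input_channels=3, class_number=256, down_channel=256)\n",
        "dummy = torch.randn(2, 3, 42, 42)  # batch x channel x ne x ne\n",
        "print(unet(dummy).shape)  # expected: torch.Size([2, 42, 42, 256])"
      ]
    },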
    {
      "cell_type": "markdown",
      "id": "db4c5c74",
      "metadata": {
        "id": "db4c5c74"
      },
      "source": [
        "## Loss function"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "50cf7653",
      "metadata": {
        "id": "50cf7653"
      },
      "outputs": [],
      "source": [
        "def multilabel_categorical_crossentropy(y_true, y_pred):\n",
        "    y_pred = (1 - 2 * y_true) * y_pred\n",
        "    y_pred_neg = y_pred - y_true * 1e30\n",
        "    y_pred_pos = y_pred - (1 - y_true) * 1e30\n",
        "    zeros = torch.zeros_like(y_pred[..., :1])\n",
        "    y_pred_neg = torch.cat([y_pred_neg, zeros],dim=-1)\n",
        "    y_pred_pos = torch.cat((y_pred_pos, zeros),dim=-1)\n",
        "    neg_loss = torch.logsumexp(y_pred_neg, axis=-1)\n",
        "    pos_loss = torch.logsumexp(y_pred_pos, axis=-1)\n",
        "    return neg_loss + pos_loss\n",
        "\n",
        "\n",
        "class ATLoss(nn.Module):\n",
        "    def __init__(self):\n",
        "        super().__init__()\n",
        "\n",
        "    def forward(self, logits, labels):\n",
        "\n",
        "        loss = multilabel_categorical_crossentropy(labels,logits)\n",
        "        loss = loss.mean()\n",
        "        return loss\n",
        "\n",
        "    def get_label(self, logits, num_labels=-1):\n",
        "        th_logit = torch.zeros_like(logits[..., :1])\n",
        "        output = torch.zeros_like(logits).to(logits)\n",
        "        mask = (logits > th_logit)\n",
        "        if num_labels > 0:\n",
        "            top_v, _ = torch.topk(logits, num_labels, dim=1)\n",
        "            top_v = top_v[:, -1]\n",
        "            mask = (logits >= top_v.unsqueeze(1)) & mask\n",
        "        output[mask] = 1.0\n",
        "        output[:, 0] = (output[:,1:].sum(1) == 0.).to(logits)\n",
        "\n",
        "        return output"
      ]
    },
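    {
      "cell_type": "markdown",
      "id": "4a1c2e65",
      "metadata": {
        "id": "4a1c2e65"
      },
      "source": [
        "For reference, `multilabel_categorical_crossentropy` computes $\\log\\big(1 + \\sum_{i \\in \\Omega_{neg}} e^{s_i}\\big) + \\log\\big(1 + \\sum_{j \\in \\Omega_{pos}} e^{-s_j}\\big)$, which pushes positive-class logits above 0 and negative-class logits below 0. Accordingly, `ATLoss.get_label` predicts a relation whenever its logit is positive and falls back to the NA class (index 0) when none is. The toy demo below (made-up logits) shows this behavior."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "5b2d3f76",
      "metadata": {
        "id": "5b2d3f76"
      },
      "outputs": [],
      "source": [
        "# Toy demo of ATLoss.get_label: a class is predicted when its logit > 0;\n",
        "# if no relation logit is positive, the NA class (index 0) is set.\n",
        "demo_logits = torch.tensor([[0.2, 1.5, -0.3, 2.0],\n",
        "                            [-0.1, -0.5, -0.2, -0.9]])\n",
        "print(ATLoss().get_label(demo_logits))\n",
        "# expected: tensor([[0., 1., 0., 1.],\n",
        "#                   [1., 0., 0., 0.]])"
      ]
    },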
    {
      "cell_type": "markdown",
      "id": "be7fed6f",
      "metadata": {
        "id": "be7fed6f"
      },
      "source": [
        "## Preprocess the inputs"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "ab2c84f8",
      "metadata": {
        "id": "ab2c84f8"
      },
      "outputs": [],
      "source": [
        "def process_long_input(model, input_ids, attention_mask, start_tokens, end_tokens):\n",
        "    # Split the input to 2 overlapping chunks. Now BERT can encode inputs of which the length are up to 1024.\n",
        "    n, c = input_ids.size()\n",
        "    start_tokens = torch.tensor(start_tokens).to(input_ids)\n",
        "    end_tokens = torch.tensor(end_tokens).to(input_ids)\n",
        "    len_start = start_tokens.size(0)\n",
        "    len_end = end_tokens.size(0)\n",
        "    if c <= 512:\n",
        "        output = model(\n",
        "            input_ids=input_ids,\n",
        "            attention_mask=attention_mask,\n",
        "            output_attentions=True,\n",
        "        )\n",
        "        sequence_output = output[0]\n",
        "        attention = output[-1][-1]\n",
        "    else:\n",
        "        new_input_ids, new_attention_mask, num_seg = [], [], []\n",
        "        seq_len = attention_mask.sum(1).cpu().numpy().astype(np.int32).tolist()\n",
        "        for i, l_i in enumerate(seq_len):\n",
        "            if l_i <= 512:\n",
        "                new_input_ids.append(input_ids[i, :512])\n",
        "                new_attention_mask.append(attention_mask[i, :512])\n",
        "                num_seg.append(1)\n",
        "            else:\n",
        "                input_ids1 = torch.cat([input_ids[i, :512 - len_end], end_tokens], dim=-1)\n",
        "                input_ids2 = torch.cat([start_tokens, input_ids[i, (l_i - 512 + len_start): l_i]], dim=-1)\n",
        "                attention_mask1 = attention_mask[i, :512]\n",
        "                attention_mask2 = attention_mask[i, (l_i - 512): l_i]\n",
        "                new_input_ids.extend([input_ids1, input_ids2])\n",
        "                new_attention_mask.extend([attention_mask1, attention_mask2])\n",
        "                num_seg.append(2)\n",
        "        input_ids = torch.stack(new_input_ids, dim=0)\n",
        "        attention_mask = torch.stack(new_attention_mask, dim=0)\n",
        "        output = model(\n",
        "            input_ids=input_ids,\n",
        "            attention_mask=attention_mask,\n",
        "            output_attentions=True,\n",
        "        )\n",
        "        sequence_output = output[0]\n",
        "        attention = output[-1][-1]\n",
        "        i = 0\n",
        "        new_output, new_attention = [], []\n",
        "        for (n_s, l_i) in zip(num_seg, seq_len):\n",
        "            if n_s == 1:\n",
        "                output = F.pad(sequence_output[i], (0, 0, 0, c - 512))\n",
        "                att = F.pad(attention[i], (0, c - 512, 0, c - 512))\n",
        "                new_output.append(output)\n",
        "                new_attention.append(att)\n",
        "            elif n_s == 2:\n",
        "                output1 = sequence_output[i][:512 - len_end]\n",
        "                mask1 = attention_mask[i][:512 - len_end]\n",
        "                att1 = attention[i][:, :512 - len_end, :512 - len_end]\n",
        "                output1 = F.pad(output1, (0, 0, 0, c - 512 + len_end))\n",
        "                mask1 = F.pad(mask1, (0, c - 512 + len_end))\n",
        "                att1 = F.pad(att1, (0, c - 512 + len_end, 0, c - 512 + len_end))\n",
        "\n",
        "                output2 = sequence_output[i + 1][len_start:]\n",
        "                mask2 = attention_mask[i + 1][len_start:]\n",
        "                att2 = attention[i + 1][:, len_start:, len_start:]\n",
        "                output2 = F.pad(output2, (0, 0, l_i - 512 + len_start, c - l_i))\n",
        "                mask2 = F.pad(mask2, (l_i - 512 + len_start, c - l_i))\n",
        "                att2 = F.pad(att2, [l_i - 512 + len_start, c - l_i, l_i - 512 + len_start, c - l_i])\n",
        "                mask = mask1 + mask2 + 1e-10\n",
        "                output = (output1 + output2) / mask.unsqueeze(-1)\n",
        "                att = (att1 + att2)\n",
        "                att = att / (att.sum(-1, keepdim=True) + 1e-10)\n",
        "                new_output.append(output)\n",
        "                new_attention.append(att)\n",
        "            i += n_s\n",
        "        sequence_output = torch.stack(new_output, dim=0)\n",
        "        attention = torch.stack(new_attention, dim=0)\n",
        "    return sequence_output, attention"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "bb59e10c",
      "metadata": {
        "id": "bb59e10c"
      },
      "source": [
        "## Auxiliary functions"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "2c46f18d",
      "metadata": {
        "id": "2c46f18d"
      },
      "outputs": [],
      "source": [
        "def set_seed(cfg):\n",
        "    random.seed(cfg.seed)\n",
        "    np.random.seed(cfg.seed)\n",
        "    torch.manual_seed(cfg.seed)\n",
        "\n",
        "def collate_fn(batch):\n",
        "    max_len = max([len(f[\"input_ids\"]) for f in batch])\n",
        "    input_ids = [f[\"input_ids\"] + [0] * (max_len - len(f[\"input_ids\"])) for f in batch]\n",
        "    input_mask = [[1.0] * len(f[\"input_ids\"]) + [0.0] * (max_len - len(f[\"input_ids\"])) for f in batch]\n",
        "    input_ids = torch.tensor(input_ids, dtype=torch.long)\n",
        "    input_mask = torch.tensor(input_mask, dtype=torch.float)\n",
        "    entity_pos = [f[\"entity_pos\"] for f in batch]\n",
        "\n",
        "    labels = [f[\"labels\"] for f in batch]\n",
        "    hts = [f[\"hts\"] for f in batch]\n",
        "    output = (input_ids, input_mask, labels, entity_pos, hts )\n",
        "    return output\n",
        "\n",
        "def to_official(args, preds, features):\n",
        "    rel2id = json.load(open(f'{args.data_dir}/rel2id.json', 'r'))\n",
        "    id2rel = {value: key for key, value in rel2id.items()}\n",
        "    \n",
        "    h_idx, t_idx, title = [], [], []\n",
        "\n",
        "    for f in features:\n",
        "        hts = f[\"hts\"]\n",
        "        h_idx += [ht[0] for ht in hts]\n",
        "        t_idx += [ht[1] for ht in hts]\n",
        "        title += [f[\"title\"] for ht in hts]\n",
        "\n",
        "    res = []\n",
        "    # print('h_idx, preds', len(h_idx), len(preds))\n",
        "    # assert len(h_idx) == len(preds)\n",
        "\n",
        "\n",
        "    for i in range(preds.shape[0]):\n",
        "        pred = preds[i]\n",
        "        pred = np.nonzero(pred)[0].tolist()\n",
        "        for p in pred:\n",
        "            if p != 0:\n",
        "                res.append(\n",
        "                    {\n",
        "                        'title': title[i],\n",
        "                        'h_idx': h_idx[i],\n",
        "                        't_idx': t_idx[i],\n",
        "                        'r': id2rel[p],\n",
        "                    }\n",
        "                )\n",
        "    return res\n",
        "\n",
        "def gen_train_facts(data_file_name, truth_dir):\n",
        "    fact_file_name = data_file_name[data_file_name.find(\"train_\"):]\n",
        "    fact_file_name = os.path.join(truth_dir, fact_file_name.replace(\".json\", \".fact\"))\n",
        "\n",
        "    if os.path.exists(fact_file_name):\n",
        "        fact_in_train = set([])\n",
        "        triples = json.load(open(fact_file_name))\n",
        "        for x in triples:\n",
        "            fact_in_train.add(tuple(x))\n",
        "        return fact_in_train\n",
        "\n",
        "    fact_in_train = set([])\n",
        "    ori_data = json.load(open(data_file_name))\n",
        "    for data in ori_data:\n",
        "        vertexSet = data['vertexSet']\n",
        "        for label in data['labels']:\n",
        "            rel = label['r']\n",
        "            for n1 in vertexSet[label['h']]:\n",
        "                for n2 in vertexSet[label['t']]:\n",
        "                    fact_in_train.add((n1['name'], n2['name'], rel))\n",
        "\n",
        "    json.dump(list(fact_in_train), open(fact_file_name, \"w\"))\n",
        "\n",
        "    return fact_in_train\n",
        "\n",
        "def official_evaluate(tmp, path):\n",
        "    '''\n",
        "        Adapted from the official evaluation code\n",
        "    '''\n",
        "    truth_dir = os.path.join(path, 'ref')\n",
        "\n",
        "    if not os.path.exists(truth_dir):\n",
        "        os.makedirs(truth_dir)\n",
        "\n",
        "    fact_in_train_annotated = gen_train_facts(os.path.join(path, \"train_annotated.json\"), truth_dir)\n",
        "\n",
        "    if not os.path.exists(os.path.join(path, \"train_distant.json\")):\n",
        "        raise FileNotFoundError(\"Sorry, the file: 'train_annotated.json' is too big to upload to github, \\\n",
        "            please manually download to 'data/' from DocRED GoogleDrive https://drive.google.com/drive/folders/1c5-0YwnoJx8NS6CV2f-NoTHR__BdkNqw\")\n",
        "    fact_in_train_distant = gen_train_facts(os.path.join(path, \"train_distant.json\"), truth_dir)\n",
        "\n",
        "    truth = json.load(open(os.path.join(path, \"dev.json\")))\n",
        "\n",
        "    std = {}\n",
        "    tot_evidences = 0\n",
        "    titleset = set([])\n",
        "\n",
        "    title2vectexSet = {}\n",
        "\n",
        "    for x in truth:\n",
        "        title = x['title']\n",
        "        titleset.add(title)\n",
        "\n",
        "        vertexSet = x['vertexSet']\n",
        "        title2vectexSet[title] = vertexSet\n",
        "\n",
        "        for label in x['labels']:\n",
        "            r = label['r']\n",
        "            h_idx = label['h']\n",
        "            t_idx = label['t']\n",
        "            std[(title, r, h_idx, t_idx)] = set(label['evidence'])\n",
        "            tot_evidences += len(label['evidence'])\n",
        "\n",
        "    tot_relations = len(std)\n",
        "    tmp.sort(key=lambda x: (x['title'], x['h_idx'], x['t_idx'], x['r']))\n",
        "    submission_answer = [tmp[0]]\n",
        "    for i in range(1, len(tmp)):\n",
        "        x = tmp[i]\n",
        "        y = tmp[i - 1]\n",
        "        if (x['title'], x['h_idx'], x['t_idx'], x['r']) != (y['title'], y['h_idx'], y['t_idx'], y['r']):\n",
        "            submission_answer.append(tmp[i])\n",
        "\n",
        "    correct_re = 0\n",
        "    correct_evidence = 0\n",
        "    pred_evi = 0\n",
        "\n",
        "    correct_in_train_annotated = 0\n",
        "    correct_in_train_distant = 0\n",
        "    titleset2 = set([])\n",
        "    for x in submission_answer:\n",
        "        title = x['title']\n",
        "        h_idx = x['h_idx']\n",
        "        t_idx = x['t_idx']\n",
        "        r = x['r']\n",
        "        titleset2.add(title)\n",
        "        if title not in title2vectexSet:\n",
        "            continue\n",
        "        vertexSet = title2vectexSet[title]\n",
        "\n",
        "        if 'evidence' in x:\n",
        "            evi = set(x['evidence'])\n",
        "        else:\n",
        "            evi = set([])\n",
        "        pred_evi += len(evi)\n",
        "\n",
        "        if (title, r, h_idx, t_idx) in std:\n",
        "            correct_re += 1\n",
        "            stdevi = std[(title, r, h_idx, t_idx)]\n",
        "            correct_evidence += len(stdevi & evi)\n",
        "            in_train_annotated = in_train_distant = False\n",
        "            for n1 in vertexSet[h_idx]:\n",
        "                for n2 in vertexSet[t_idx]:\n",
        "                    if (n1['name'], n2['name'], r) in fact_in_train_annotated:\n",
        "                        in_train_annotated = True\n",
        "                    if (n1['name'], n2['name'], r) in fact_in_train_distant:\n",
        "                        in_train_distant = True\n",
        "\n",
        "            if in_train_annotated:\n",
        "                correct_in_train_annotated += 1\n",
        "            if in_train_distant:\n",
        "                correct_in_train_distant += 1\n",
        "\n",
        "    re_p = 1.0 * correct_re / len(submission_answer)\n",
        "    re_r = 1.0 * correct_re / tot_relations\n",
        "    if re_p + re_r == 0:\n",
        "        re_f1 = 0\n",
        "    else:\n",
        "        re_f1 = 2.0 * re_p * re_r / (re_p + re_r)\n",
        "\n",
        "    evi_p = 1.0 * correct_evidence / pred_evi if pred_evi > 0 else 0\n",
        "    evi_r = 1.0 * correct_evidence / tot_evidences\n",
        "    if evi_p + evi_r == 0:\n",
        "        evi_f1 = 0\n",
        "    else:\n",
        "        evi_f1 = 2.0 * evi_p * evi_r / (evi_p + evi_r)\n",
        "\n",
        "    re_p_ignore_train_annotated = 1.0 * (correct_re - correct_in_train_annotated) / (len(submission_answer) - correct_in_train_annotated + 1e-5)\n",
        "    re_p_ignore_train = 1.0 * (correct_re - correct_in_train_distant) / (len(submission_answer) - correct_in_train_distant + 1e-5)\n",
        "\n",
        "    if re_p_ignore_train_annotated + re_r == 0:\n",
        "        re_f1_ignore_train_annotated = 0\n",
        "    else:\n",
        "        re_f1_ignore_train_annotated = 2.0 * re_p_ignore_train_annotated * re_r / (re_p_ignore_train_annotated + re_r)\n",
        "\n",
        "    if re_p_ignore_train + re_r == 0:\n",
        "        re_f1_ignore_train = 0\n",
        "    else:\n",
        "        re_f1_ignore_train = 2.0 * re_p_ignore_train * re_r / (re_p_ignore_train + re_r)\n",
        "\n",
        "    return re_f1, evi_f1, re_f1_ignore_train_annotated, re_f1_ignore_train, re_p, re_r"
      ]
    },
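    {
      "cell_type": "markdown",
      "id": "6c3e4a87",
      "metadata": {
        "id": "6c3e4a87"
      },
      "source": [
        "A tiny, self-contained check of `collate_fn`: `input_ids` are right-padded with 0 to the longest sequence in the batch, and the attention mask marks the real tokens. The values below are made up."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "7d4f5b98",
      "metadata": {
        "id": "7d4f5b98"
      },
      "outputs": [],
      "source": [
        "# Toy batch through collate_fn (made-up token ids; no entities/labels).\n",
        "toy = [{'input_ids': [101, 5, 6, 102], 'entity_pos': [], 'labels': [], 'hts': []},\n",
        "       {'input_ids': [101, 7, 102], 'entity_pos': [], 'labels': [], 'hts': []}]\n",
        "ids, mask, labels, entity_pos, hts = collate_fn(toy)\n",
        "print(ids)   # second row padded with a trailing 0\n",
        "print(mask)  # 1.0 for real tokens, 0.0 for padding"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "8e5a6c09",
      "metadata": {
        "id": "8e5a6c09"
      },
      "source": [
        "`official_evaluate` reports the standard DocRED metrics: precision $P = \\frac{\\text{correct}}{|\\text{predictions}|}$, recall $R = \\frac{\\text{correct}}{|\\text{gold triples}|}$, and $F_1 = \\frac{2PR}{P+R}$. The *Ign F1* variants recompute precision after discounting predicted triples whose (head name, tail name, relation) facts already appear in the training data (`fact_in_train_annotated` / `fact_in_train_distant`), so they measure generalization to unseen facts."
      ]
    },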
    {
      "cell_type": "markdown",
      "id": "bd91859a",
      "metadata": {
        "id": "bd91859a"
      },
      "source": [
        "## Train the model\n",
        "### Config parameters"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "5c8b82c7",
      "metadata": {
        "id": "5c8b82c7"
      },
      "outputs": [],
      "source": [
        "class Config(object):\n",
        "    adam_epsilon=1e-06\n",
        "    bert_lr=3e-05\n",
        "    channel_type='context-based'\n",
        "    config_name=''\n",
        "    data_dir='./data'\n",
        "    dataset='docred'\n",
        "    dev_file='dev.json'\n",
        "    down_dim=256\n",
        "    evaluation_steps=-1\n",
        "    gradient_accumulation_steps=2\n",
        "    learning_rate=0.0004\n",
        "    log_dir='./train_roberta.log'\n",
        "    max_grad_norm=1.0\n",
        "    max_height=42\n",
        "    max_seq_length=1024\n",
        "    model_name_or_path='roberta-base'\n",
        "    num_class=97\n",
        "    num_labels=4\n",
        "    num_train_epochs=30\n",
        "    save_path='./model_roberta.pt'\n",
        "    seed=111\n",
        "    test_batch_size=2\n",
        "    test_file='test.json'\n",
        "    tokenizer_name=''\n",
        "    train_batch_size=2\n",
        "    train_file='train_annotated.json'\n",
        "    train_from_saved_model=''\n",
        "    transformer_type='roberta'\n",
        "    unet_in_dim=3\n",
        "    unet_out_dim=256\n",
        "    warmup_ratio=0.06\n",
        "    load_path='./model_roberta.pt'\n",
        "    \n",
        "cfg = Config()"
      ]
    },
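    {
      "cell_type": "markdown",
      "id": "9f6b7d10",
      "metadata": {
        "id": "9f6b7d10"
      },
      "source": [
        "A note on these settings: with `train_batch_size=2` and `gradient_accumulation_steps=2`, gradients from two mini-batches are accumulated before each optimizer step, so the effective batch size is $2 \\times 2 = 4$ and the scheduler sees roughly `len(train_dataloader) // 2` steps per epoch. In `train` below, the encoder parameters use `bert_lr=3e-05`, the extractor/bilinear layers use `1e-4`, and the remaining parameters use `learning_rate=0.0004`, via AdamW parameter groups."
      ]
    },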
    {
      "cell_type": "markdown",
      "id": "92647af6",
      "metadata": {
        "id": "92647af6"
      },
      "source": [
        "### Model Training"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "48b9679a",
      "metadata": {
        "id": "48b9679a"
      },
      "outputs": [],
      "source": [
        "def train(args, model, train_features, dev_features, test_features):\n",
        "    def logging(s, print_=True, log_=True):\n",
        "        if print_:\n",
        "            print(s)\n",
        "        if log_ and args.log_dir != '':\n",
        "            with open(args.log_dir, 'a+') as f_log:\n",
        "                f_log.write(s + '\\n')\n",
        "    def finetune(features, optimizer, num_epoch, num_steps, model):\n",
        "        cur_model = model.module if hasattr(model, 'module') else model\n",
        "        device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
        "        if args.train_from_saved_model != '':\n",
        "            best_score = torch.load(args.train_from_saved_model)[\"best_f1\"]\n",
        "            epoch_delta = torch.load(args.train_from_saved_model)[\"epoch\"] + 1\n",
        "        else:\n",
        "            epoch_delta = 0\n",
        "            best_score = -1\n",
        "        train_dataloader = DataLoader(features, batch_size=args.train_batch_size, shuffle=True, collate_fn=collate_fn, drop_last=True)\n",
        "        train_iterator = [epoch + epoch_delta for epoch in range(num_epoch)]\n",
        "        total_steps = int(len(train_dataloader) * num_epoch // args.gradient_accumulation_steps)\n",
        "        warmup_steps = int(total_steps * args.warmup_ratio)\n",
        "        scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=warmup_steps, num_training_steps=total_steps)\n",
        "        print(\"Total steps: {}\".format(total_steps))\n",
        "        print(\"Warmup steps: {}\".format(warmup_steps))\n",
        "        global_step = 0\n",
        "        log_step = 100\n",
        "        total_loss = 0\n",
        "        \n",
        "\n",
        "\n",
        "        #scaler = GradScaler()\n",
        "        for epoch in train_iterator:\n",
        "            start_time = time.time()\n",
        "            optimizer.zero_grad()\n",
        "\n",
        "            for step, batch in enumerate(train_dataloader):\n",
        "                model.train()\n",
        "\n",
        "                inputs = {'input_ids': batch[0].to(device),\n",
        "                          'attention_mask': batch[1].to(device),\n",
        "                          'labels': batch[2],\n",
        "                          'entity_pos': batch[3],\n",
        "                          'hts': batch[4],\n",
        "                          }\n",
        "                #with autocast():\n",
        "                outputs = model(**inputs)\n",
        "                loss = outputs[0] / args.gradient_accumulation_steps\n",
        "                total_loss += loss.item()\n",
        "                #    scaler.scale(loss).backward()\n",
        "               \n",
        "\n",
        "                loss.backward()\n",
        "\n",
        "                if step % args.gradient_accumulation_steps == 0:\n",
        "                    #scaler.unscale_(optimizer)\n",
        "                    if args.max_grad_norm > 0:\n",
        "                        # torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_grad_norm)\n",
        "                        torch.nn.utils.clip_grad_norm_(cur_model.parameters(), args.max_grad_norm)\n",
        "                    #scaler.step(optimizer)\n",
        "                    #scaler.update()\n",
        "                    #scheduler.step()\n",
        "                    optimizer.step()\n",
        "                    scheduler.step()\n",
        "                    optimizer.zero_grad()\n",
        "                    global_step += 1\n",
        "                    num_steps += 1\n",
        "                    if global_step % log_step == 0:\n",
        "                        cur_loss = total_loss / log_step\n",
        "                        elapsed = time.time() - start_time\n",
        "                        logging(\n",
        "                            '| epoch {:2d} | step {:4d} | min/b {:5.2f} | lr {} | train loss {:5.3f}'.format(\n",
        "                                epoch, global_step, elapsed / 60, scheduler.get_last_lr(), cur_loss))\n",
        "                        total_loss = 0\n",
        "                        start_time = time.time()\n",
        "\n",
        "                if (step + 1) == len(train_dataloader) - 1 or (args.evaluation_steps > 0 and num_steps % args.evaluation_steps == 0 and step % args.gradient_accumulation_steps == 0):\n",
        "                # if step ==0:\n",
        "                    logging('-' * 89)\n",
        "                    eval_start_time = time.time()\n",
        "                    dev_score, dev_output = evaluate(args, model, dev_features, tag=\"dev\")\n",
        "\n",
        "                    logging(\n",
        "                        '| epoch {:3d} | time: {:5.2f}s | dev_result:{}'.format(epoch, time.time() - eval_start_time,\n",
        "                                                                                dev_output))\n",
        "                    logging('-' * 89)\n",
        "                    if dev_score > best_score:\n",
        "                        best_score = dev_score\n",
        "                        logging(\n",
        "                            '| epoch {:3d} | best_f1:{}'.format(epoch, best_score))\n",
        "                        if args.save_path != \"\":\n",
        "                            torch.save({\n",
        "                                'epoch': epoch,\n",
        "                                'checkpoint': cur_model.state_dict(),\n",
        "                                'best_f1': best_score,\n",
        "                                'optimizer': optimizer.state_dict()\n",
        "                            }, args.save_path\n",
        "                            , _use_new_zipfile_serialization=False)\n",
        "                            logging(\n",
        "                                '| successfully save model at: {}'.format(args.save_path))\n",
        "                            logging('-' * 89)\n",
        "        return num_steps\n",
        "\n",
        "    cur_model = model.module if hasattr(model, 'module') else model\n",
        "    extract_layer = [\"extractor\", \"bilinear\"]\n",
        "    bert_layer = ['bert_model']\n",
        "    optimizer_grouped_parameters = [\n",
        "        {\"params\": [p for n, p in cur_model.named_parameters() if any(nd in n for nd in bert_layer)], \"lr\": args.bert_lr},\n",
        "        {\"params\": [p for n, p in cur_model.named_parameters() if any(nd in n for nd in extract_layer)], \"lr\": 1e-4},\n",
        "        {\"params\": [p for n, p in cur_model.named_parameters() if not any(nd in n for nd in extract_layer + bert_layer)]},\n",
        "    ]\n",
        "\n",
        "    optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)\n",
        "    if args.train_from_saved_model != '':\n",
        "        optimizer.load_state_dict(torch.load(args.train_from_saved_model)[\"optimizer\"])\n",
        "        print(\"load saved optimizer from {}.\".format(args.train_from_saved_model))\n",
        "    \n",
        "\n",
        "    num_steps = 0\n",
        "    set_seed(args)\n",
        "    model.zero_grad()\n",
        "    finetune(train_features, optimizer, args.num_train_epochs, num_steps, model)\n",
        "\n",
        "def evaluate(args, model, features, tag=\"dev\"):\n",
        "    device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
        "    dataloader = DataLoader(features, batch_size=args.test_batch_size, shuffle=False, collate_fn=collate_fn, drop_last=False)\n",
        "    preds = []\n",
        "    total_loss = 0\n",
        "    for i, batch in enumerate(dataloader):\n",
        "        model.eval()\n",
        "\n",
        "        inputs = {'input_ids': batch[0].to(device),\n",
        "                  'attention_mask': batch[1].to(device),\n",
        "                  'labels': batch[2],\n",
        "                  'entity_pos': batch[3],\n",
        "                  'hts': batch[4],\n",
        "                  }\n",
        "\n",
        "        with torch.no_grad():\n",
        "            output = model(**inputs)\n",
        "            loss = output[0]\n",
        "            pred = output[1].cpu().numpy()\n",
        "            pred[np.isnan(pred)] = 0\n",
        "            preds.append(pred)\n",
        "            total_loss += loss.item()\n",
        "\n",
        "    average_loss = total_loss / (i + 1)\n",
        "    preds = np.concatenate(preds, axis=0).astype(np.float32)\n",
        "    ans = to_official(args, preds, features)\n",
        "    if len(ans) > 0:\n",
        "        best_f1, _, best_f1_ign, _, re_p, re_r = official_evaluate(ans, args.data_dir)\n",
        "    output = {\n",
        "        tag + \"_F1\": best_f1 * 100,\n",
        "        tag + \"_F1_ign\": best_f1_ign * 100,\n",
        "        tag + \"_re_p\": re_p * 100,\n",
        "        tag + \"_re_r\": re_r * 100,\n",
        "        tag + \"_average_loss\": average_loss\n",
        "    }\n",
        "    return best_f1, output"
      ]
    },
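    {
      "cell_type": "markdown",
      "id": "a1f2b3c4",
      "metadata": {
        "id": "a1f2b3c4"
      },
      "source": [
        "The model's raw scores are decoded with an *adaptive threshold*: for each entity pair, a relation is kept only if its score exceeds the score of the NA/threshold class (index 0). The cell below is a minimal sketch of this decoding convention on toy logits; the helper name `decode_predictions` is ours and is only an illustration, not DeepKE's own implementation."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "d5e6f7a8",
      "metadata": {
        "id": "d5e6f7a8"
      },
      "outputs": [],
      "source": [
        "# Minimal sketch of adaptive-threshold decoding (assumption: class 0 is the\n",
        "# NA/threshold class, as in ATLOP/DocuNet-style models). Illustration only.\n",
        "import numpy as np\n",
        "\n",
        "def decode_predictions(logits):\n",
        "    \"\"\"logits: [num_pairs, num_classes] -> 0/1 predictions per relation.\"\"\"\n",
        "    th = logits[:, 0:1]                       # per-pair threshold score\n",
        "    preds = (logits > th).astype(np.float32)  # keep relations above the threshold\n",
        "    preds[:, 0] = 0.0                         # never predict the NA class itself\n",
        "    return preds\n",
        "\n",
        "demo = np.array([[0.5, 1.2, 0.1],   # pair 0: class 1 beats NA -> predicted\n",
        "                 [0.9, 0.2, 0.3]])  # pair 1: nothing beats NA -> no relation\n",
        "print(decode_predictions(demo))"
      ]
    },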
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "026c8be0",
      "metadata": {
        "id": "026c8be0"
      },
      "outputs": [],
      "source": [
        "if not os.path.exists(os.path.join(cfg.data_dir, \"train_distant.json\")):\n",
        "    raise FileNotFoundError(\"Sorry, the file: 'train_annotated.json' is too big to upload to github, \\\n",
        "        please manually download to 'data/' from DocRED GoogleDrive https://drive.google.com/drive/folders/1c5-0YwnoJx8NS6CV2f-NoTHR__BdkNqw\")\n",
        "\n",
        "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
        "\n",
        "config = AutoConfig.from_pretrained(\n",
        "    cfg.config_name if cfg.config_name else cfg.model_name_or_path,\n",
        "    num_labels=cfg.num_class,\n",
        ")\n",
        "tokenizer = AutoTokenizer.from_pretrained(\n",
        "    cfg.tokenizer_name if cfg.tokenizer_name else cfg.model_name_or_path,\n",
        ")\n",
        "\n",
        "Dataset = ReadDataset(cfg.dataset, tokenizer, cfg.max_seq_length, cfg.transformer_type)\n",
        "\n",
        "train_file = os.path.join(cfg.data_dir, cfg.train_file)\n",
        "dev_file = os.path.join(cfg.data_dir, cfg.dev_file)\n",
        "test_file = os.path.join(cfg.data_dir, cfg.test_file)\n",
        "train_features = Dataset.read(train_file)\n",
        "dev_features = Dataset.read(dev_file)\n",
        "test_features = Dataset.read(test_file)\n",
        "\n",
        "model = AutoModel.from_pretrained(\n",
        "    cfg.model_name_or_path,\n",
        "    from_tf=bool(\".ckpt\" in cfg.model_name_or_path),\n",
        "    config=config,\n",
        ")\n",
        "\n",
        "\n",
        "config.cls_token_id = tokenizer.cls_token_id\n",
        "config.sep_token_id = tokenizer.sep_token_id\n",
        "config.transformer_type = cfg.transformer_type\n",
        "\n",
        "set_seed(cfg)\n",
        "model = DocREModel(config, cfg,  model, num_labels=cfg.num_labels)\n",
        "if cfg.train_from_saved_model != '':\n",
        "    model.load_state_dict(torch.load(cfg.train_from_saved_model)[\"checkpoint\"])\n",
        "    print(\"load saved model from {}.\".format(cfg.train_from_saved_model))\n",
        "\n",
        "#if torch.cuda.device_count() > 1:\n",
        "#    print(\"Let's use\", torch.cuda.device_count(), \"GPUs!\")\n",
        "#    model = torch.nn.DataParallel(model, device_ids = list(range(torch.cuda.device_count())))\n",
        "model.to(device)\n",
        "\n",
        "train(cfg, model, train_features, dev_features, test_features)"
      ]
    },
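    {
      "cell_type": "markdown",
      "id": "b2c3d4e5",
      "metadata": {
        "id": "b2c3d4e5"
      },
      "source": [
        "The relation labels the model predicts are Wikidata property IDs such as `P17`. As a quick sanity check (a small sketch, assuming the `./data/` layout described earlier), the cell below loads `rel2id.json`, which maps these IDs to class indices, and `rel_info.json`, which maps them to human-readable names."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "c3d4e5f6",
      "metadata": {
        "id": "c3d4e5f6"
      },
      "outputs": [],
      "source": [
        "# Sanity check on the label files (paths assume the data layout above).\n",
        "import json\n",
        "import os\n",
        "\n",
        "with open(os.path.join(cfg.data_dir, \"rel2id.json\")) as f:\n",
        "    rel2id = json.load(f)    # relation ID (e.g. \"P17\") -> class index\n",
        "with open(os.path.join(cfg.data_dir, \"rel_info.json\")) as f:\n",
        "    rel_info = json.load(f)  # relation ID -> human-readable name\n",
        "\n",
        "print(\"number of relation classes:\", len(rel2id))\n",
        "for rid in list(rel2id)[:3]:\n",
        "    print(rid, \"->\", rel2id[rid], rel_info.get(rid, \"NA\"))"
      ]
    },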
    {
      "cell_type": "markdown",
      "id": "14c983c7",
      "metadata": {
        "id": "14c983c7"
      },
      "source": [
        "### Model Prediction"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "cbc55160",
      "metadata": {
        "id": "cbc55160"
      },
      "outputs": [],
      "source": [
        "def report(args, model, features):\n",
        "    device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
        "    dataloader = DataLoader(features, batch_size=args.test_batch_size, shuffle=False, collate_fn=collate_fn, drop_last=False)\n",
        "    preds = []\n",
        "    for batch in dataloader:\n",
        "        model.eval()\n",
        "\n",
        "        inputs = {'input_ids': batch[0].to(device),\n",
        "                  'attention_mask': batch[1].to(device),\n",
        "                  'entity_pos': batch[3],\n",
        "                  'hts': batch[4],\n",
        "                  }\n",
        "\n",
        "        with torch.no_grad():\n",
        "            pred = model(**inputs)\n",
        "            pred = pred.cpu().numpy()\n",
        "            pred[np.isnan(pred)] = 0\n",
        "            preds.append(pred)\n",
        "\n",
        "    preds = np.concatenate(preds, axis=0).astype(np.float32)\n",
        "    preds = to_official(args, preds, features)\n",
        "    return preds\n",
        "\n",
        "model.load_state_dict(torch.load(cfg.load_path)['checkpoint'])\n",
        "T_features = test_features  # Testing on the test set\n",
        "#T_score, T_output = evaluate(cfg, model, T_features, tag=\"test\")\n",
        "pred = report(cfg, model, T_features)\n",
        "with open(\"./result.json\", \"w\") as fh:\n",
        "    json.dump(pred, fh)"
      ]
    }
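    ,
    {
      "cell_type": "markdown",
      "id": "d4e5f6a7",
      "metadata": {
        "id": "d4e5f6a7"
      },
      "source": [
        "Finally, we can inspect a few of the saved predictions. `to_official` produces the official DocRED submission format; we assume here that each entry carries at least the fields `title`, `h_idx`, `t_idx` and `r` (a relation ID that `rel_info.json` can translate into a readable name)."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "e5f6a7b8",
      "metadata": {
        "id": "e5f6a7b8"
      },
      "outputs": [],
      "source": [
        "# Peek at the saved predictions (assumed format: one dict per predicted\n",
        "# triple, with title / h_idx / t_idx / r fields).\n",
        "import json\n",
        "\n",
        "with open(\"./result.json\") as fh:\n",
        "    predictions = json.load(fh)\n",
        "\n",
        "print(\"number of predicted triples:\", len(predictions))\n",
        "for p in predictions[:5]:\n",
        "    print(p)"
      ]
    }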
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3 (ipykernel)",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.8.11"
    },
    "colab": {
      "name": "document_re_tutorial.ipynb",
      "provenance": [],
      "collapsed_sections": [],
      "private_outputs": true
    },
    "accelerator": "GPU"
  },
  "nbformat": 4,
  "nbformat_minor": 5
}
